package org.apache.spark.ml.feature


import org.apache.commons.lang3.StringUtils
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession


/**
  * Created by peibin on 2017/4/13.
  */
/**
  * Trains three independent linear-regression models to predict the
  * forward / comment / like counts of Tianchi weibo posts, using features
  * extracted from the post content by [[TianchiWeiboExtractor]].
  *
  * Created by peibin on 2017/4/13.
  */
object Weibo {

  /**
    * Entry point.
    *
    * @param args optional: args(0) is the path to the training data file;
    *             defaults to the original hard-coded local path.
    */
  def main(args: Array[String]): Unit = {
    import scala.util.control.NonFatal

    val spark = SparkSession
      .builder()
      .appName("weibo")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Input path is now overridable from the command line; the original
    // hard-coded path is kept as the default for backward compatibility.
    val inputPath = args.headOption.getOrElse(
      "/Users/peibin/workspace/data-mining/tianchi/weibo/weibo_train_data.txt")

    // Each line is expected to be 7 tab-separated fields:
    // uid \t mid \t time \t forwardCount \t commentCount \t likeCount \t content.
    // Rows with the wrong field count or non-numeric counts are silently dropped.
    val data = spark.read.textFile(inputPath).flatMap { line =>
      val arr = StringUtils.split(line, '\t')
      if (arr.length == 7) {
        try {
          Some(Weibo(arr(0), arr(1), arr(2), arr(3).toInt, arr(4).toInt, arr(5).toInt, arr(6)))
        } catch {
          // NonFatal (instead of the previous bare catch-all) so fatal errors
          // such as OutOfMemoryError or InterruptedException still propagate.
          case NonFatal(_) => None
        }
      } else {
        None
      }
    }
    data.show(10)

    val Array(train, test) = data.randomSplit(Array(0.6, 0.4))

    // Extracts a numeric feature vector ("features") from the raw content.
    // NOTE(review): the previous HashingTF/IDF stages were removed — they were
    // never added to the pipeline and read a "words" column no stage produced.
    val extractor = new TianchiWeiboExtractor().setInputCol("content").setOutputCol("features")

    // One regressor per target; all three share the same feature column and
    // write into distinct prediction columns.
    val lr1 = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("forwardCount")
      .setPredictionCol("forwardPredict") // fixed typo: was "fowardPredict"
    val lr2 = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("commentCount")
      .setPredictionCol("commentPredict")
    val lr3 = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("likeCount")
      .setPredictionCol("likePredict")

    val pipeline = new Pipeline().setStages(Array(extractor, lr1, lr2, lr3))
    val model = pipeline.fit(train)
    model.transform(test).show(10)

    spark.stop() // release local Spark resources (was missing)
  }

  /** One parsed row of the Tianchi weibo training data. */
  case class Weibo(uid: String, mid: String, time: String, forwardCount: Int, commentCount: Int, likeCount: Int, content: String)

}
